{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Federated learning - Train a reinforcement learning agent in a CartPole environment\n",
    "\n",
    "This tutorial demonstrates how to train a reinforcement learning agent with federated learning in a CartPole environment. Before running it, you need to install OpenAI Gym.\n",
    "\n",
    "To train our agent, we use a policy implemented as a simple neural network that maps the CartPole environment's state space to an action space. This policy is trained using federated learning with the help of the PySyft library. The program simulates policy training on a remote machine (represented by the remote worker Bob).\n",
    "\n",
    "#### References: 1. [PyTorch Examples](https://github.com/pytorch/examples/tree/master/reinforcement_learning)\n",
    "\n",
    "#### Author: Amit Rastogi   Github: [@amit-rastogi](https://github.com/amit-rastogi)   Twitter: [@amitrastogi](https://twitter.com/amitrastogi2202)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tf_encrypted:Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow (1.13.1). Fix this by compiling custom ops.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "from torch.distributions import Categorical\n",
    "import gym\n",
    "import numpy as np\n",
    "import syft as sy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create CartPole environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('CartPole-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hook Torch and create a virtual remote worker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement our neural network policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Policy(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Policy, self).__init__()\n",
    "        self.input = nn.Linear(4, 4)\n",
    "        self.output = nn.Linear(4, 2)\n",
    "\n",
    "        self.episode_log_probs = []\n",
    "        self.episode_raw_rewards = []\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.input(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.output(x)\n",
    "        x = F.softmax(x, dim=1)\n",
    "        return x"
   ]
  },
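  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the architecture above, the sketch below runs one made-up observation through a network with the same layer sizes and confirms that the output is a single row of two action probabilities. `TinyPolicy` and `dummy_state` are hypothetical names used only for illustration, and nothing here is sent to a remote worker.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class TinyPolicy(nn.Module):\n",
    "    # Hypothetical stand-in with the same layer sizes: 4 state values in, 2 actions out.\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.input = nn.Linear(4, 4)\n",
    "        self.output = nn.Linear(4, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.input(x))\n",
    "        return F.softmax(self.output(x), dim=1)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Made-up CartPole-like observation with a leading batch dimension of one.\n",
    "dummy_state = torch.tensor([[0.01, -0.02, 0.03, 0.04]])\n",
    "probs = TinyPolicy()(dummy_state)\n",
    "print(probs.shape, probs.sum().item())\n",
    "```\n",
    "\n",
    "The softmax over `dim=1` is why the state needs a batch dimension, which `select_action` adds with `unsqueeze(0)`."
   ]
  },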
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = Policy()\n",
    "optimizer = optim.SGD(params=policy.parameters(), lr=0.03)\n",
    "#discount rate to be used for action score calculation\n",
    "discount_rate = 0.95"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_action(state):\n",
    "    state = torch.from_numpy(state).float().unsqueeze(0)\n",
    "    #send the environment state to bob\n",
    "    state = state.send(bob)\n",
    "    probs = policy(state)\n",
    "    #bring the estimated probabilities back to the local worker to sample the action,\n",
    "    #since Categorical does not yet support remote tensor operations\n",
    "    probs = probs.get()\n",
    "    m = Categorical(probs)\n",
    "    action = m.sample()\n",
    "    policy.episode_log_probs.append(m.log_prob(action))\n",
    "    #get the state back as we would be sending the new state to bob\n",
    "    state.get()\n",
    "    return action.item()\n",
    "\n",
    "def discount_and_normalize_rewards():\n",
    "    discounted_rewards = []\n",
    "    cumulative_rewards = 0\n",
    "    \n",
    "    #walk the episode backwards so each step accumulates its discounted future rewards\n",
    "    for reward in policy.episode_raw_rewards[::-1]:\n",
    "        cumulative_rewards = reward + discount_rate * cumulative_rewards\n",
    "        discounted_rewards.insert(0, cumulative_rewards)\n",
    "    \n",
    "    #normalize to zero mean and unit standard deviation\n",
    "    discounted_rewards = torch.tensor(discounted_rewards)\n",
    "    discounted_rewards = (discounted_rewards - discounted_rewards.mean())/discounted_rewards.std()\n",
    "    \n",
    "    return discounted_rewards\n",
    "\n",
    "def update_policy():\n",
    "    policy_loss = []\n",
    "    discounted_rewards = discount_and_normalize_rewards()\n",
    "    for log_prob, action_score in zip(policy.episode_log_probs, discounted_rewards):\n",
    "        policy_loss.append(-log_prob * action_score)\n",
    "    \n",
    "    optimizer.zero_grad()\n",
    "    policy_loss = torch.cat(policy_loss).sum()\n",
    "    policy_loss.backward()\n",
    "    optimizer.step()\n",
    "    del policy.episode_log_probs[:]\n",
    "    del policy.episode_raw_rewards[:]"
   ]
  },
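  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reward processing above concrete, here is a minimal worked example on a made-up three-step episode. It uses plain Python lists instead of torch tensors so the arithmetic is easy to follow; `raw_rewards` is an illustrative value, not data from a real run.\n",
    "\n",
    "```python\n",
    "from statistics import mean, stdev\n",
    "\n",
    "discount_rate = 0.95\n",
    "raw_rewards = [1.0, 1.0, 1.0]  # made-up episode: one reward per step\n",
    "\n",
    "# Walk the episode backwards so each step accumulates its discounted future rewards.\n",
    "discounted = []\n",
    "running = 0.0\n",
    "for reward in reversed(raw_rewards):\n",
    "    running = reward + discount_rate * running\n",
    "    discounted.insert(0, running)\n",
    "\n",
    "# Sample standard deviation, which is also the default for torch.Tensor.std().\n",
    "mu, sigma = mean(discounted), stdev(discounted)\n",
    "normalized = [(g - mu) / sigma for g in discounted]\n",
    "\n",
    "print([round(g, 4) for g in discounted])  # [2.8525, 1.95, 1.0]\n",
    "```\n",
    "\n",
    "Earlier actions receive larger scores because they are credited with the rewards that follow them. After normalization the scores have zero mean, so actions with above-average scores are reinforced and the rest are discouraged."
   ]
  },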
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train our Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average reward: 19.17\tMax reward: 83.00\n"
     ]
    }
   ],
   "source": [
    "total_rewards = []\n",
    "# send the policy to bob for training\n",
    "policy.send(bob)\n",
    "for episode in range(500):\n",
    "    state = env.reset()\n",
    "    episode_rewards = 0\n",
    "    for step in range(1000):\n",
    "        action = select_action(state)\n",
    "        state, reward, done, _ = env.step(action)\n",
    "        #env.render()  #uncomment to render the current environment\n",
    "        policy.episode_raw_rewards.append(reward)\n",
    "        episode_rewards += reward\n",
    "        \n",
    "        if done:\n",
    "            break        \n",
    "    #to keep track of rewards earned in each episode\n",
    "    total_rewards.append(episode_rewards)\n",
    "    update_policy()\n",
    "\n",
    "#cleanup\n",
    "policy.get()\n",
    "bob.clear_objects()\n",
    "print('Average reward: {:.2f}\\tMax reward: {:.2f}'.format(np.mean(total_rewards), np.max(total_rewards)))"
   ]
  },
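  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single average over all episodes hides whether the rewards improve during training. One possible way to check is a moving average over the per-episode rewards, sketched below with NumPy. The helper `moving_average` and the sample `rewards` are hypothetical and not part of the training code above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def moving_average(values, window):\n",
    "    # Average each run of `window` consecutive values.\n",
    "    kernel = np.ones(window) / window\n",
    "    return np.convolve(values, kernel, mode='valid')\n",
    "\n",
    "rewards = [10.0, 20.0, 30.0, 40.0, 50.0]  # made-up per-episode rewards\n",
    "smoothed = moving_average(rewards, window=2)\n",
    "print(smoothed)  # [15. 25. 35. 45.]\n",
    "```\n",
    "\n",
    "Applied to `total_rewards` with a larger window, a rising curve would suggest that the policy is improving."
   ]
  },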
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Well Done!\n",
    "\n",
    "Our agent managed to keep the pole upright for a maximum of 83 consecutive steps using a very simple neural network policy trained with federated learning in PySyft.\n",
    "\n",
    "## Limitations\n",
    "\n",
    "In the `select_action` function, we have to bring the estimated probabilities back to our local worker to sample the action, since Categorical does not yet support remote tensor operations."
   ]
  },
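  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The local sampling step described above can be sketched in isolation with plain PyTorch. The probabilities below are made up and stand in for the values returned by `probs.get()`; no remote worker is involved.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import Categorical\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Hypothetical action probabilities for one state (batch of one, two actions).\n",
    "probs = torch.tensor([[0.3, 0.7]])\n",
    "\n",
    "m = Categorical(probs)\n",
    "action = m.sample()            # tensor of shape [1] holding 0 or 1\n",
    "log_prob = m.log_prob(action)  # log of the probability of the sampled action\n",
    "print(action.item(), log_prob.item())\n",
    "```\n",
    "\n",
    "The log-probability is the quantity stored in `policy.episode_log_probs` and later weighted by the action score in the policy loss."
   ]
  },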
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "## Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "## Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "## Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community!\n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org/)\n",
    "\n",
    "## Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked Good First Issue.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "## Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
